import collections.abc as collections
import numbers
import os
import tempfile
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict, namedtuple
from tempfile import _TemporaryFileWrapper  # type: ignore
from typing import Callable, Mapping, Optional, Union

import torch
import torch.nn as nn

import ignite.distributed as idist
from ignite.base import Serializable
from ignite.engine import Engine, Events

__all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler"]


class BaseSaveHandler(metaclass=ABCMeta):
    """Base class for save handlers

    Methods to override:

    - :meth:`~ignite.handlers.checkpoint.BaseSaveHandler.__call__`
    - :meth:`~ignite.handlers.checkpoint.BaseSaveHandler.remove`


    Note:
        In a derived class, please make sure that in a distributed configuration overridden methods are called by a
        single process. Distributed configurations on XLA devices should be treated slightly differently: to save a
        checkpoint with `xm.save() <https://pytorch.org/xla/release/1.5/index.html#torch_xla.core.xla_model.save>`_,
        all processes should enter the function. Otherwise, the application gets stuck.

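    Example:
        A minimal sketch of a derived handler that keeps checkpoints in memory (the class and attribute
        names below are illustrative, not part of the library):

        .. code-block:: python

            class InMemorySaveHandler(BaseSaveHandler):

                def __init__(self):
                    self.storage = {}  # filename -> (checkpoint, metadata)

                def __call__(self, checkpoint, filename, metadata=None):
                    self.storage[filename] = (checkpoint, metadata)

                def remove(self, filename):
                    del self.storage[filename]
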
    """

    @abstractmethod
    def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mapping] = None) -> None:
        """Method to save `checkpoint` with `filename`. Additionally, metadata dictionary is provided.

        Metadata contains:

        - `basename`: file prefix (if provided) with checkpoint name, e.g. `epoch_checkpoint`.
        - `score_name`: score name if provided, e.g. `val_acc`.
        - `priority`: checkpoint priority value (higher is better), e.g. `12` or `0.6554435`.

        Args:
            checkpoint (Mapping): checkpoint dictionary to save.
            filename (str): filename associated with checkpoint.
            metadata (Mapping, optional): metadata on checkpoint to save.

        """
        pass

    @abstractmethod
    def remove(self, filename: str) -> None:
        """Method to remove saved checkpoint.

        Args:
            filename (str): filename associated with checkpoint.

        """
        pass


class Checkpoint(Serializable):
    """Checkpoint handler can be used to periodically save and load objects which have attribute
    ``state_dict`/`load_state_dict``. This class can use specific save handlers to store on the disk or a cloud
    storage, etc. The Checkpoint handler (if used with :class:`~ignite.handlers.DiskSaver`) also handles automatically
    moving data on TPU to CPU before writing the checkpoint.

    Args:
        to_save (Mapping): Dictionary with the objects to save. Objects should have implemented ``state_dict`` and
            ``load_state_dict`` methods. If contains objects of type torch `DistributedDataParallel`_ or
            `DataParallel`_, their internal wrapped model is automatically saved (to avoid additional key ``module.`` in
            the state dictionary).
        save_handler (callable or :class:`~ignite.handlers.checkpoint.BaseSaveHandler`): Method or callable class to
            use to save the engine and other provided objects. The function receives two objects: the checkpoint as a
            dictionary and the filename. If ``save_handler`` is a callable class, it can inherit from
            :class:`~ignite.handlers.checkpoint.BaseSaveHandler` and optionally implement a ``remove`` method to keep
            a fixed number of saved checkpoints. If the user needs to save the engine's checkpoint on a disk,
            ``save_handler`` can be defined with :class:`~ignite.handlers.DiskSaver`.
        filename_prefix (str, optional): Prefix for the file name to which objects will be saved. See Note for details.
        score_function (callable, optional): If not None, it should be a function taking a single argument,
            :class:`~ignite.engine.engine.Engine` object, and returning a score (`float`). Objects with highest scores
            will be retained.
        score_name (str, optional): If ``score_function`` is not None, it is possible to store its value using
            ``score_name``. See Notes for more details.
        n_saved (int, optional): Number of objects that should be kept on disk. Older files will be removed. If set to
            `None`, all objects are kept.
        global_step_transform (callable, optional): global step transform function to output a desired global step.
            Input of the function is ``(engine, event_name)``; output of the function should be an integer.
            Default is None, meaning that global_step is based on the attached engine. If provided, the function
            output is used as global_step. To setup a global step from another engine, please use
            :meth:`~ignite.handlers.global_step_from_engine`.
        archived (bool, optional): Deprecated argument as models saved by ``torch.save`` are already compressed.
        filename_pattern (str, optional): If ``filename_pattern`` is provided, this pattern will be used to render
            checkpoint filenames. If the pattern is not defined, the default pattern would be used. See Note for
            details.
        include_self (bool): Whether to include the `state_dict` of this object in the checkpoint. If `True`, then
            there must not be another object in ``to_save`` with key ``checkpointer``.

    .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
        torch.nn.parallel.DistributedDataParallel.html
    .. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html

    Note:
        This class stores a single file as a dictionary of provided objects to save.
        The filename is defined by ``filename_pattern`` and by default has the following
        structure: ``{filename_prefix}_{name}_{suffix}.{ext}`` where

        - ``filename_prefix`` is the argument passed to the constructor,
        - `name` is the key in ``to_save`` if a single object is to be stored, otherwise `name` is "checkpoint".
        - `suffix` is composed as follows: ``{global_step}_{score_name}={score}``.

    +----------------+------------+-----------------------+----------------------------------------------+
    | score_function | score_name | global_step_transform |  suffix                                      |
    +================+============+=======================+==============================================+
    |      None      |   None     |        None           | ``{engine.state.iteration}``                 |
    +----------------+------------+-----------------------+----------------------------------------------+
    |       X        |   None     |        None           | ``{score}``                                  |
    +----------------+------------+-----------------------+----------------------------------------------+
    |       X        |   None     |         X             | ``{global_step}_{score}``                    |
    +----------------+------------+-----------------------+----------------------------------------------+
    |       X        |    X       |         X             | ``{global_step}_{score_name}={score}``       |
    +----------------+------------+-----------------------+----------------------------------------------+
    |      None      |   None     |         X             | ``{global_step}``                            |
    +----------------+------------+-----------------------+----------------------------------------------+
    |       X        |    X       |        None           | ``{score_name}={score}``                     |
    +----------------+------------+-----------------------+----------------------------------------------+

    Above, `global_step` is defined by the output of `global_step_transform` and `score` is defined by the output
    of `score_function`.

    By default, none of ``score_function``, ``score_name``, ``global_step_transform`` is defined, and the suffix is
    set by the attached engine's current iteration. The filename will be
    `{filename_prefix}_{name}_{engine.state.iteration}.{ext}`.

    For example, with ``score_name="neg_val_loss"`` and a ``score_function`` that returns `-loss` (as objects with
    the highest scores are retained), the saved filename will be ``{filename_prefix}_{name}_neg_val_loss=-0.1234.pt``.

    Note:
        If ``filename_pattern`` is given, it will be used to render the filenames. ``filename_pattern`` is a string
        that can contain ``{filename_prefix}``, ``{name}``, ``{score}``, ``{score_name}`` and ``{global_step}`` as
        templates.

        For example, with ``filename_pattern="{global_step}-{name}-{score}.pt"``, the saved filename will be
        ``30000-checkpoint-94.pt``.

        **Warning:** Please keep in mind that if the filename collides with one already used to save a checkpoint,
        the new checkpoint will not be stored. This means that a filename like ``checkpoint.pt`` will be saved only
        once and will not be overwritten by newer checkpoints.
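
        For instance, a sketch of passing a custom pattern (``to_save``, ``save_handler``, ``score_function`` and
        ``trainer`` are assumed to be defined as in the Examples below):

        .. code-block:: python

            handler = Checkpoint(
                to_save, save_handler,
                filename_pattern="{global_step}-{name}-{score}.pt",
                score_function=score_function,
                global_step_transform=global_step_from_engine(trainer),
            )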

    Note:
        To get the last stored filename, the handler exposes the attribute ``last_checkpoint``:

        .. code-block:: python

            handler = Checkpoint(...)
            ...
            print(handler.last_checkpoint)
            > checkpoint_12345.pt

    Note:
        This class is distributed configuration-friendly: it is not required to instantiate the class in the rank 0
        process only. This class automatically supports distributed configurations and, if used with
        :class:`~ignite.handlers.DiskSaver`, the checkpoint is stored by the rank 0 process.

    .. warning::

        When running on XLA devices, the handler should be executed in all processes, otherwise the application can
        get stuck while saving the checkpoint.

        .. code-block:: python

            # Wrong:
            # if idist.get_rank() == 0:
            #     handler = Checkpoint(...)
            #     trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

            # Correct:
            handler = Checkpoint(...)
            trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

    Examples:

        Attach the handler to make checkpoints during training:

        .. code-block:: python

            from ignite.engine import Engine, Events
            from ignite.handlers import Checkpoint, DiskSaver

            trainer = ...
            model = ...
            optimizer = ...
            lr_scheduler = ...

            to_save = {'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'trainer': trainer}
            handler = Checkpoint(to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2)
            trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)
            trainer.run(data_loader, max_epochs=6)
            > ["checkpoint_7000.pt", "checkpoint_8000.pt", ]

        Attach the handler to an evaluator to save the best model during the training
        according to a computed validation metric:

        .. code-block:: python

            from ignite.engine import Engine, Events
            from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine

            trainer = ...
            evaluator = ...
            # Setup Accuracy metric computation on evaluator
            # Run evaluation on epoch completed event
            # ...

            def score_function(engine):
                return engine.state.metrics['accuracy']

            to_save = {'model': model}
            handler = Checkpoint(to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2,
                                 filename_prefix='best', score_function=score_function, score_name="val_acc",
                                 global_step_transform=global_step_from_engine(trainer))

            evaluator.add_event_handler(Events.COMPLETED, handler)

            trainer.run(data_loader, max_epochs=10)
            > ["best_model_9_val_acc=0.77.pt", "best_model_10_val_acc=0.78.pt", ]

    """

    Item = namedtuple("Item", ["priority", "filename"])
    _state_dict_all_req_keys = ("saved",)

    def __init__(
        self,
        to_save: Optional[Mapping],
        save_handler: Union[Callable, BaseSaveHandler],
        filename_prefix: str = "",
        score_function: Optional[Callable] = None,
        score_name: Optional[str] = None,
        n_saved: Optional[int] = 1,
        global_step_transform: Optional[Callable] = None,
        archived: bool = False,
        filename_pattern: Optional[str] = None,
        include_self: bool = False,
    ):

        if to_save is not None:  # for compatibility with ModelCheckpoint
            if not isinstance(to_save, collections.Mapping):
                raise TypeError("Argument `to_save` should be a dictionary, but given {}".format(type(to_save)))

            if len(to_save) < 1:
                raise ValueError("No objects to checkpoint.")

            self._check_objects(to_save, "state_dict")

            if include_self:
                if not isinstance(to_save, collections.MutableMapping):
                    raise TypeError(
                        "If `include_self` is True, then `to_save` must be mutable, but given {}.".format(type(to_save))
                    )

                if "checkpointer" in to_save:
                    raise ValueError("Cannot have key 'checkpointer' if `include_self` is True: {}".format(to_save))

        if not (callable(save_handler) or isinstance(save_handler, BaseSaveHandler)):
            raise TypeError("Argument `save_handler` should be callable or inherit from BaseSaveHandler")

        if score_function is None and score_name is not None:
            raise ValueError("If `score_name` is provided, then `score_function` " "should be also provided.")

        if global_step_transform is not None and not callable(global_step_transform):
            raise TypeError(
                "global_step_transform should be a function, got {} instead.".format(type(global_step_transform))
            )
        if archived:
            warnings.warn("Argument archived is deprecated and will be removed in 0.5.0")

        self.to_save = to_save
        self.filename_prefix = filename_prefix
        self.save_handler = save_handler
        self.score_function = score_function
        self.score_name = score_name
        self.n_saved = n_saved
        self.ext = "pt"
        self.global_step_transform = global_step_transform
        self.filename_pattern = filename_pattern
        self._saved = []  # type: list
        self.include_self = include_self

    @property
    def last_checkpoint(self) -> Optional[str]:
        if len(self._saved) < 1:
            return None
        return self._saved[-1].filename

    def _check_lt_n_saved(self, or_equal=False):
        if self.n_saved is None:
            return True
        return len(self._saved) < self.n_saved + int(or_equal)

    def __call__(self, engine: Engine) -> None:

        global_step = None
        if self.global_step_transform is not None:
            global_step = self.global_step_transform(engine, engine.last_event_name)

        if self.score_function is not None:
            priority = self.score_function(engine)
            if not isinstance(priority, numbers.Number):
                raise ValueError("Output of score_function should be a number")
        else:
            if global_step is None:
                global_step = engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED)
            priority = global_step

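        # Save only if there is still room for a new checkpoint, or if the new priority is higher than
        # the lowest stored one (`self._saved` is kept sorted by ascending priority)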
        if self._check_lt_n_saved() or self._saved[0].priority < priority:

            priority_str = (
                "{}".format(priority) if isinstance(priority, numbers.Integral) else "{:.4f}".format(priority)
            )

            checkpoint = self._setup_checkpoint()

            name = "checkpoint"
            if len(checkpoint) == 1:
                for k in checkpoint:
                    name = k
                checkpoint = checkpoint[name]

            if self.filename_pattern is None:
                filename_pattern = self.setup_filename_pattern(
                    with_prefix=len(self.filename_prefix) > 0,
                    with_score=self.score_function is not None,
                    with_score_name=self.score_name is not None,
                    with_global_step=global_step is not None,
                )
            else:
                filename_pattern = self.filename_pattern

            filename_dict = {
                "filename_prefix": self.filename_prefix,
                "ext": self.ext,
                "name": name,
                "score_name": self.score_name,
                "score": priority_str if (self.score_function is not None) else None,
                "global_step": global_step,
            }
            filename = filename_pattern.format(**filename_dict)

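            # A filename that was already used for a stored checkpoint is never overwritten (see class docstring)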
            if any(item.filename == filename for item in self._saved):
                return

            metadata = {
                "basename": "{}{}{}".format(self.filename_prefix, "_" * int(len(self.filename_prefix) > 0), name),
                "score_name": self.score_name,
                "priority": priority,
            }

            if not self._check_lt_n_saved():
                item = self._saved.pop(0)
                if isinstance(self.save_handler, BaseSaveHandler):
                    self.save_handler.remove(item.filename)

            self._saved.append(Checkpoint.Item(priority, filename))
            self._saved.sort(key=lambda item: item[0])

            if self.include_self:
                # Now that we've updated _saved, we can add our own state_dict.
                checkpoint["checkpointer"] = self.state_dict()

            try:
                self.save_handler(checkpoint, filename, metadata)
            except TypeError:
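                # the save handler may not accept a metadata argument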
                self.save_handler(checkpoint, filename)

    def _setup_checkpoint(self) -> dict:
        checkpoint = {}
        if self.to_save is not None:
            for k, obj in self.to_save.items():
                if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                    obj = obj.module
                checkpoint[k] = obj.state_dict()
        return checkpoint

    @staticmethod
    def setup_filename_pattern(
        with_prefix: bool = True, with_score: bool = True, with_score_name: bool = True, with_global_step: bool = True,
    ) -> str:
        """Helper method to get the default filename pattern for a checkpoint.

        Args:
            with_prefix (bool): If True, the ``filename_prefix`` is added to the filename pattern:
                ``{filename_prefix}_{name}...``. Default, True.
            with_score (bool): If True, ``score`` is added to the filename pattern: ``..._{score}.{ext}``.
                Default, True. At least one of ``with_score`` and ``with_global_step`` should be True.
            with_score_name (bool): If True, ``score_name`` is added to the filename pattern:
                ``..._{score_name}={score}.{ext}``. If activated, argument ``with_score`` should be
                also True, otherwise an error is raised. Default, True.
            with_global_step (bool): If True, ``{global_step}`` is added to the
                filename pattern: ``...{name}_{global_step}...``.
                At least one of ``with_score`` and ``with_global_step`` should be True.

        Example:

            .. code-block:: python

                from ignite.handlers import Checkpoint

                filename_pattern = Checkpoint.setup_filename_pattern()

                print(filename_pattern)
                > "{filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}"
        """
        filename_pattern = "{name}"

        if not (with_global_step or with_score):
            raise ValueError("At least one of with_score and with_global_step should be True.")

        if with_global_step:
            filename_pattern += "_{global_step}"

        if with_score_name and with_score:
            filename_pattern += "_{score_name}={score}"
        elif with_score:
            filename_pattern += "_{score}"
        elif with_score_name:
            raise ValueError("If with_score_name is True, with_score should be also True")

        if with_prefix:
            filename_pattern = "{filename_prefix}_" + filename_pattern

        filename_pattern += ".{ext}"
        return filename_pattern

    @staticmethod
    def _check_objects(objs: Mapping, attr: str) -> None:
        for k, obj in objs.items():
            if not hasattr(obj, attr):
                raise TypeError("Object {} should have `{}` method".format(type(obj), attr))

    @staticmethod
    def load_objects(to_load: Mapping, checkpoint: Mapping, **kwargs) -> None:
        """Helper method to apply ``load_state_dict`` on the objects from ``to_load`` using states from ``checkpoint``.

        Examples:

        .. code-block:: python

            import torch
            from ignite.engine import Engine, Events
            from ignite.handlers import ModelCheckpoint, Checkpoint
            trainer = Engine(lambda engine, batch: None)
            handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=None, create_dir=True)
            model = torch.nn.Linear(3, 3)
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            to_save = {"weights": model, "optimizer": optimizer}
            trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
            trainer.run(torch.randn(10, 1), 5)

            to_load = to_save
            checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
            checkpoint = torch.load(checkpoint_fp)
            Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

        Note:
            If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
            `DataParallel`_, the method ``load_state_dict`` will be applied to their internal wrapped model
            (``obj.module``).

        Args:
            to_load (Mapping): a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
            checkpoint (Mapping): a dictionary with state_dicts to load, e.g. `{"model": model_state_dict,
                "optimizer": opt_state_dict}`. If `to_load` contains a single key, then checkpoint can contain directly
                corresponding state_dict.
            **kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
                the user to load part of the pretrained model (useful, for example, in transfer learning).

        .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
            torch.nn.parallel.DistributedDataParallel.html
        .. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html

        """
        Checkpoint._check_objects(to_load, "load_state_dict")
        if not isinstance(checkpoint, collections.Mapping):
            raise TypeError("Argument checkpoint should be a dictionary, but given {}".format(type(checkpoint)))

        if any(k not in ["strict"] for k in kwargs):
            warnings.warn("kwargs contains keys other than 'strict' and these will be ignored")

        is_state_dict_strict = kwargs.get("strict", True)
        if len(to_load) == 1:
            # single object and checkpoint is directly a state_dict
            key, obj = list(to_load.items())[0]
            if key not in checkpoint:
                if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                    obj = obj.module
                obj.load_state_dict(checkpoint, strict=is_state_dict_strict)
                return

        # multiple objects to load
        for k, obj in to_load.items():
            if k not in checkpoint:
                raise ValueError("Object labeled by '{}' from `to_load` is not found in the checkpoint".format(k))
            if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                obj = obj.module
            if isinstance(obj, torch.nn.Module):
                obj.load_state_dict(checkpoint[k], strict=is_state_dict_strict)
            else:
                obj.load_state_dict(checkpoint[k])

    def state_dict(self) -> OrderedDict:
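        """Method returns state dict with saved items: a list of ``(priority, filename)`` pairs.
        Can be used to save the internal state of the class.
        """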
        return OrderedDict([("saved", [(p, f) for p, f in self._saved])])

    def load_state_dict(self, state_dict: Mapping) -> None:
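        """Method replaces internal state of the class with the provided state dict data.

        Args:
            state_dict (Mapping): a dict with a "saved" key whose value is a list of ``(priority, filename)`` pairs.
        """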
        super().load_state_dict(state_dict)
        self._saved = [Checkpoint.Item(p, f) for p, f in state_dict["saved"]]


class DiskSaver(BaseSaveHandler):
    """Handler that saves input checkpoint on a disk.

    Args:
        dirname (str): Directory path where the checkpoint will be saved
        atomic (bool, optional): if True, checkpoint is serialized to a temporary file, and then
            moved to final destination, so that files are guaranteed to not be damaged
            (for example if exception occurs during saving).
        create_dir (bool, optional): if True, will create directory ``dirname`` if it doesn't exist.
        require_empty (bool, optional): If True, will raise exception if there are any files in the
            directory ``dirname``.
        **kwargs: Accepted keyword arguments for `torch.save` or `xm.save`.
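
    Example:
        A minimal usage sketch together with :class:`~ignite.handlers.checkpoint.Checkpoint`
        (``to_save`` and ``trainer`` are assumed to be defined):

        .. code-block:: python

            from ignite.engine import Events
            from ignite.handlers import Checkpoint, DiskSaver

            handler = Checkpoint(to_save, DiskSaver("/tmp/models", create_dir=True, require_empty=False))
            trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)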
    """

    def __init__(
        self, dirname: str, atomic: bool = True, create_dir: bool = True, require_empty: bool = True, **kwargs
    ):
        self.dirname = os.path.expanduser(dirname)
        self._atomic = atomic
        self._check_and_setup(self.dirname, create_dir, require_empty)
        self.kwargs = kwargs

    @staticmethod
    @idist.one_rank_only()
    def _check_and_setup(dirname, create_dir, require_empty):
        if create_dir:
            if not os.path.exists(dirname):
                os.makedirs(dirname)
        # Ensure that dirname exists
        if not os.path.exists(dirname):
            raise ValueError("Directory path '{}' is not found".format(dirname))

        if require_empty:
            matched = [fname for fname in os.listdir(dirname) if fname.endswith(".pt")]
            if len(matched) > 0:
                raise ValueError(
                    "Files {} with extension '.pt' are already present "
                    "in the directory {}. If you want to use this "
                    "directory anyway, pass `require_empty=False`."
                    "".format(matched, dirname)
                )

    def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mapping] = None) -> None:
        path = os.path.join(self.dirname, filename)

        if idist.has_xla_support:
            self._save_xla(checkpoint, path)
        else:
            self._save_native(checkpoint, path)

    @idist.one_rank_only()
    def _save_native(self, checkpoint: Mapping, path: str):
        self._save_func(checkpoint, path, torch.save)

    def _save_xla(self, checkpoint: Mapping, path: str):
        import torch_xla.core.xla_model as xm  # type: ignore

        # all TPU processes should enter here, as xm.save() internally performs a sync across devices
        self._save_func(checkpoint, path, xm.save, rank=idist.get_rank())

    def _save_func(self, checkpoint: Mapping, path: str, func: Callable, rank: int = 0):
        if not self._atomic:
            func(checkpoint, path, **self.kwargs)
        else:
            # Serialize the checkpoint into a temporary file first and move it to the final
            # destination only on success, so that `path` is never left half-written.
            tmp_file = None
            tmp_name = ""
            tmp = None  # type: Optional[_TemporaryFileWrapper]
            if rank == 0:
                tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname)
                tmp_file = tmp.file
                tmp_name = tmp.name
            try:
                func(checkpoint, tmp_file, **self.kwargs)
            except BaseException:
                if tmp is not None:
                    tmp.close()
                    os.remove(tmp_name)
                # re-raise on all ranks, not only on the rank that created the temporary file
                raise
            else:
                if tmp is not None:
                    tmp.close()
                    os.rename(tmp_name, path)

    @idist.one_rank_only()
    def remove(self, filename: str) -> None:
        path = os.path.join(self.dirname, filename)
        os.remove(path)


class ModelCheckpoint(Checkpoint):
    """ModelCheckpoint handler can be used to periodically save objects to disk only. If needed to store checkpoints to
    another storage type, please consider :class:`~ignite.handlers.checkpoint.Checkpoint`.

    This handler expects two arguments:

        - an :class:`~ignite.engine.engine.Engine` object
        - a `dict` mapping names (`str`) to objects that should be saved to disk.

    See Examples for further details.

    .. warning::

        Behaviour of this class has been changed since v0.3.0.

        Argument ``save_as_state_dict`` is deprecated and should not be used. It is considered to be True.

        Argument ``save_interval`` is deprecated and should not be used. Please, use events filtering instead, e.g.
        :attr:`~ignite.engine.events.Events.ITERATION_STARTED(every=1000)`

        There is no longer an internal counter used to indicate the number of save actions. Previously, the user could
        see its value as `step_number` in the filename, e.g. `{filename_prefix}_{name}_{step_number}.pt`. Now,
        `step_number` is replaced by the current engine's epoch if `score_function` is specified, and by the current
        iteration otherwise.

        A single `pt` file is created instead of multiple files.

    Args:
        dirname (str): Directory path where objects will be saved.
        filename_prefix (str): Prefix for the file names to which objects will be saved. See Notes of
            :class:`~ignite.handlers.Checkpoint` for more details.
        score_function (callable, optional): if not None, it should be a function taking a single argument, an
            :class:`~ignite.engine.engine.Engine` object, and returning a score (`float`). Objects with highest scores
            will be retained.
        score_name (str, optional): if ``score_function`` is not None, it is possible to store its value using
            `score_name`. See Notes for more details.
        n_saved (int, optional): Number of objects that should be kept on disk. Older files will be removed. If set to
            `None`, all objects are kept.
        atomic (bool, optional): If True, objects are serialized to a temporary file, and then moved to final
            destination, so that files are guaranteed to not be damaged (for example if exception
            occurs during saving).
        require_empty (bool, optional): If True, will raise exception if there are any files starting with
            ``filename_prefix`` in the directory ``dirname``.
        create_dir (bool, optional): If True, will create directory ``dirname`` if it does not exist.
        global_step_transform (callable, optional): global step transform function to output a desired global step.
            Input of the function is `(engine, event_name)`; output of the function should be an integer.
            Default is None, meaning that global_step is based on the attached engine. If provided, the function
            output is used as global_step. To setup a global step from another engine, please use
            :meth:`~ignite.handlers.global_step_from_engine`.
        archived (bool, optional): Deprecated argument as models saved by `torch.save` are already compressed.
        include_self (bool): Whether to include the `state_dict` of this object in the checkpoint. If `True`, then
            there must not be another object in ``to_save`` with key ``checkpointer``.
        **kwargs: Accepted keyword arguments for `torch.save` or `xm.save` in `DiskSaver`.

    Examples:
        >>> import os
        >>> from ignite.engine import Engine, Events
        >>> from ignite.handlers import ModelCheckpoint
        >>> from torch import nn
        >>> trainer = Engine(lambda engine, batch: None)
        >>> handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2, create_dir=True)
        >>> model = nn.Linear(3, 3)
        >>> trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {'mymodel': model})
        >>> trainer.run([0], max_epochs=6)
        >>> os.listdir('/tmp/models')
        ['myprefix_mymodel_4.pt', 'myprefix_mymodel_6.pt']
        >>> handler.last_checkpoint
        '/tmp/models/myprefix_mymodel_6.pt'
    """

    def __init__(
        self,
        dirname: str,
        filename_prefix: str,
        save_interval: Optional[Callable] = None,
        score_function: Optional[Callable] = None,
        score_name: Optional[str] = None,
        n_saved: Optional[int] = 1,
        atomic: bool = True,
        require_empty: bool = True,
        create_dir: bool = True,
        save_as_state_dict: bool = True,
        global_step_transform: Optional[Callable] = None,
        archived: bool = False,
        include_self: bool = False,
        **kwargs
    ):

        if not save_as_state_dict:
            raise ValueError(
                "Argument save_as_state_dict is deprecated and should be True. "
                "This argument will be removed in 0.5.0."
            )
        if save_interval is not None:
            msg = (
                "Argument save_interval is deprecated and should be None. This argument will be removed in 0.5.0."
                "Please, use events filtering instead, e.g. Events.ITERATION_STARTED(every=1000)"
            )
            if save_interval == 1:
                # Do not break for old version who used `save_interval=1`
                warnings.warn(msg)
            else:
                # No choice
                raise ValueError(msg)

        disk_saver = DiskSaver(dirname, atomic=atomic, create_dir=create_dir, require_empty=require_empty, **kwargs)

        super(ModelCheckpoint, self).__init__(
            to_save=None,
            save_handler=disk_saver,
            filename_prefix=filename_prefix,
            score_function=score_function,
            score_name=score_name,
            n_saved=n_saved,
            global_step_transform=global_step_transform,
            archived=archived,
            include_self=include_self,
        )

    @property
    def last_checkpoint(self) -> Optional[str]:
        if len(self._saved) < 1:
            return None

        if not isinstance(self.save_handler, DiskSaver):
            raise RuntimeError(
                "Unable to save checkpoint, save_handler should be DiskSaver, got {}.".format(type(self.save_handler))
            )

        return os.path.join(self.save_handler.dirname, self._saved[-1].filename)

    def __call__(self, engine: Engine, to_save: Mapping) -> None:  # type: ignore

        if len(to_save) == 0:
            raise RuntimeError("No objects to checkpoint found.")

        self._check_objects(to_save, "state_dict")
        self.to_save = to_save
        super(ModelCheckpoint, self).__call__(engine)
